import torch
import numpy as np
import os
import click
import wandb
from bgflow.utils import (
    IndexBatchIterator,
)
from bgflow import (
    DiffEqFlow,
    BoltzmannGenerator,
    MeanFreeNormalDistribution,
    BlackBoxDynamics,
)
import ot as pot
from eq_ot_flow.estimator import BruteForceEstimatorFast
from eq_ot_flow.LJ import LennardJonesPotential
import json

from path_grad_helpers import (
    HutchinsonEstimatorDifferentiable,
    AugmentedAdjointDyn,
    path_gradient,
    path_gradient,
    train_loop,
    device,
    load_weights,
    fm_train_step_ot,
)
from LJ_utils import (
    get_data,
    get_dynamics,
)


@click.command()
@click.option("--n_training_data", default=100000)
@click.option("--n_batch", default=1024)
@click.option("--n_epochs", default=1001)
@click.option("--n_holdout", default=500000)
@click.option("--lr", default=5e-4)
@click.option("--n_knots_hutch", default=20)
@click.option(
    "--training_kind",
    default="fm",
    type=click.Choice(["fm", "path"]),
    help="How do you want to train the model?",
)
@click.option("--n_particles", default="13", type=click.Choice(["13", "55"]))
@click.option("--data_path", default="/data")
@click.option("--chkpt_path", default=None)
@click.option("--save_during_training", default=0)
@click.option("--grad_clipping", default=False)
def main(
    n_training_data=100000,
    n_batch=1024,
    n_epochs=1001,
    n_holdout=500000,
    lr=5e-4,
    n_knots_hutch=20,
    training_kind="fm",
    n_particles="13",
    data_path="/data",
    chkpt_path=None,
    save_during_training=0,
    grad_clipping=False,
):
    """Train a Boltzmann Generator on a Lennard-Jones system.

    Trains either with the flow-matching step ("fm") or with path gradients
    ("path"), optionally starting from an existing checkpoint.
    \f
    Everything after the form feed above is hidden from the click ``--help``
    output.

    Args:
        n_training_data: Number of samples kept from the shuffled data set.
        n_batch: Batch size handed to ``IndexBatchIterator``.
        n_epochs: Number of epochs handed to ``train_loop``.
        n_holdout: Only recorded in the run config; not otherwise used here.
        lr: Learning rate of the Adam optimizer.
        n_knots_hutch: ``n_time_steps`` of the rk4 flow used for path-gradient
            training (only used when ``training_kind`` is "path").
        training_kind: "fm" trains with ``fm_train_step_ot``, "path" with
            ``path_gradient``.
        n_particles: Particle count as a string ("13" or "55"); converted to
            ``int`` internally. The system dimension is ``3 * n_particles``.
        data_path: Location passed to ``get_data``.
        chkpt_path: Optional checkpoint to load before training. When given,
            the output folder name is derived from it.
        save_during_training: If greater than zero, a checkpoint is written
            for every epoch that is a multiple of this value.
        grad_clipping: Forwarded unchanged to ``train_loop``.

    Raises:
        FileExistsError: If the output folder for this configuration already
            exists, so an earlier run is never silently overwritten.
    """
    n_particles = int(n_particles)
    dim = n_particles * 3

    # Hyperparameters recorded both in wandb and in <log_folder>/config.json.
    config = {
        "n_training_data": n_training_data,
        "n_batch": n_batch,
        "n_epochs": n_epochs,
        "n_holdout": n_holdout,
        "lr": lr,
        "n_knots_hutch": n_knots_hutch,
        "training_kind": training_kind,
        "chkpt_path": chkpt_path,
        "save_during_training": save_during_training,
        "grad_clipping": grad_clipping,
    }

    wandb.init(project=f"PathGradFlowMatching-LJ{n_particles}", config=config)

    target = LennardJonesPotential(
        dim, n_particles, eps=1.0, rm=1, oscillator_scale=1, two_event_dims=False
    )
    if chkpt_path is not None:
        # Fine-tuning run: name the folder after the checkpoint. Only a
        # trailing ".pt" is stripped, so a ".pt" elsewhere in the path (for
        # example inside a directory name) is left untouched.
        chkpt_stem = chkpt_path[:-3] if chkpt_path.endswith(".pt") else chkpt_path
        log_folder = f"{chkpt_stem}-FT-{n_batch}-{training_kind}-{n_epochs}-{lr}-{n_knots_hutch}"
    else:
        log_folder = f"models/LJ{n_particles}_{n_training_data}_{n_batch}-{training_kind}-{n_epochs}-{lr}-knots{n_knots_hutch}"

    print(f"Creating folder {log_folder}")
    # makedirs also creates missing parents (e.g. "models/"); exist_ok is left
    # at its default so an already existing run folder still raises.
    os.makedirs(log_folder)
    with open(f"{log_folder}/config.json", "w") as config_file:
        json.dump(config, config_file)

    print("Loading data")
    data = get_data(data_path, n_particles)

    # Fixed seed so the same training subset is selected on every run.
    np.random.seed(0)
    idx = np.random.choice(np.arange(len(data)), len(data), replace=False)
    data_smaller = data[idx[:n_training_data]]
    print(f"Dataset size {data_smaller.shape}")

    # now set up a prior
    prior = MeanFreeNormalDistribution(dim, n_particles, two_event_dims=False).to(
        device
    )

    print("Building Flow")
    # Build the Boltzmann Generator
    net_dynamics = get_dynamics(n_particles)

    bb_dynamics = BlackBoxDynamics(
        dynamics_function=net_dynamics, divergence_estimator=BruteForceEstimatorFast()
    )
    flow = DiffEqFlow(dynamics=bb_dynamics)
    # having a flow and a prior, we can now define a Boltzmann Generator

    bg = BoltzmannGenerator(prior, flow, target.to(device))

    if chkpt_path is not None:
        load_weights(bg, chkpt_path)

    print("Setting up training loop")
    batch_iter = IndexBatchIterator(len(data_smaller), n_batch)

    optim = torch.optim.Adam(bg.parameters(), lr=lr)

    sigma = 0.01

    def batches():
        """Yield the training batches, moved to the training device."""
        for idxs in batch_iter:
            yield data_smaller[idxs].to(device)

    if training_kind == "fm":
        print("Setting up fm trainer")

        def fm_trainer(x1):
            """Run one flow-matching training step on batch ``x1``."""
            return fm_train_step_ot(x1, prior, bg, pot, sigma=sigma)

    else:
        print("Setting up Path grads")
        # This flow wraps the same net_dynamics object as bg's flow, but with
        # a different divergence estimator and a fixed-step rk4 integrator.
        path_grad_dynamics = AugmentedAdjointDyn(
            BlackBoxDynamics(
                dynamics_function=net_dynamics,
                divergence_estimator=HutchinsonEstimatorDifferentiable(),
                # divergence_estimator=BruteForceEstimatorFastDifferentiable()
            )
        )
        flow_hutch = DiffEqFlow(
            dynamics=path_grad_dynamics,
            integrator="rk4",
            n_time_steps=n_knots_hutch,
        )

        # Here bg is not used, since for training we use a different integrator/adjoint ode
        def fm_trainer(x1):
            """Run one path-gradient training step on batch ``x1``."""
            return path_gradient(x1, prior, target, flow_hutch)

    save_callback = None
    if save_during_training > 0:

        def save_callback(epoch):
            """Write an intermediate checkpoint every ``save_during_training`` epochs."""
            if (epoch % save_during_training) == 0:
                torch.save(bg.state_dict(), f"{log_folder}/chkpt-{epoch}.pt")

    print("Starting training")
    train_loop(
        n_epochs,
        bg,
        fm_trainer,
        batches,
        optim,
        save_callback=save_callback,
        grad_clipping=grad_clipping,
    )
    # Final weights after the last epoch.
    torch.save(bg.state_dict(), f"{log_folder}/chkpt.pt")


# Script entry point: run the click command only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
